3. Speaker Classification: That's what who said!¶

  • Can a transformer-based model attribute a given line of dialogue to its speaker?
  • Which keywords and phrases drive the model's prediction of a character's speaking style?

Approach¶

  • filter the dataset to lines from the top five characters (by line count)
  • encode the lines with the DistilBertTokenizer
  • fine-tune a pre-trained DistilBertForSequenceClassification model
  • train with the AdamW optimizer
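The first step, filtering to the most frequent speakers and mapping them to integer labels, can be sketched like this. The `lines` list and its contents are illustrative stand-ins for the actual script dataset, and only the top three speakers are kept here for brevity (the notebook uses five):

```python
from collections import Counter

# hypothetical (speaker, line) pairs standing in for the real dataset
lines = [
    ("Michael", "That's what she said."),
    ("Dwight", "Bears. Beets. Battlestar Galactica."),
    ("Jim", "Fact: I am better than you."),
    ("Michael", "I'm not superstitious, but I am a little stitious."),
    ("Pam", "Dunder Mifflin, this is Pam."),
    ("Andy", "Rit dit dit di doo!"),
    ("Kevin", "Why waste time say lot word?"),
    ("Michael", "I declare bankruptcy!"),
]

# keep only lines from the top-N speakers by line count
TOP_N = 3
top_speakers = [s for s, _ in Counter(s for s, _ in lines).most_common(TOP_N)]
filtered = [(s, t) for s, t in lines if s in top_speakers]

# map each retained speaker to an integer class label for the classifier
label2id = {s: i for i, s in enumerate(sorted(top_speakers))}
labels = [label2id[s] for s, _ in filtered]
```

The integer labels are what `DistilBertForSequenceClassification` is trained against; the `LABEL_0`…`LABEL_4` names in the SHAP plots below correspond to this mapping.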

Model predictions¶

In [15]:
test_sample_str = "Hi i'm Michael Scott"
test_sample = tokenizer.encode(test_sample_str, truncation=True, padding=True, return_tensors='pt')
test_sample = test_sample.to(device)
with torch.no_grad():  # inference only, no gradients needed
    output = model(test_sample)

# visualize the per-speaker probabilities
import plotly.express as px
probs = torch.softmax(output.logits, dim=1).cpu().numpy()
px.bar(x=sorted_speakers, y=probs[0], title=f"Model evaluation - Test sentence: '{test_sample_str}'")
In [16]:
test_sample_str = "Good to see you!"
test_sample = tokenizer.encode(test_sample_str, truncation=True, padding=True, return_tensors='pt')
test_sample = test_sample.to(device)
with torch.no_grad():
    output = model(test_sample)

# visualize the per-speaker probabilities
probs = torch.softmax(output.logits, dim=1).cpu().numpy()
px.bar(x=sorted_speakers, y=probs[0], title=f"Model evaluation - Test sentence: '{test_sample_str}'")
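Both cells above turn the model's five raw logits into one probability per speaker via a softmax. A minimal pure-Python version of that step (the logits here are made up for illustration, not real model outputs):

```python
import math

def softmax(logits):
    # subtract the max for numerical stability before exponentiating
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# hypothetical logits for the five speakers
logits = [2.1, -0.3, 0.4, 1.2, -1.0]
probs = softmax(logits)

best = probs.index(max(probs))  # argmax -> predicted speaker index
```

The bar charts plot exactly these probabilities, so the tallest bar is the model's predicted speaker for the test sentence.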

Model evaluation¶

In [19]:
# create a confusion matrix evaluating the above class predictions (normalized)
cm = confusion_matrix(val_labels, val_predictions, normalize='true')
cm_rand = confusion_matrix(val_labels, val_predictions_random, normalize='true')


# plot the two confusion matrices side by side
fig, ax = plt.subplots(1, 2, figsize=(12, 4))

sns.heatmap(cm, cmap='rocket_r', annot=True, fmt='.2f', xticklabels=sorted_speakers, yticklabels=sorted_speakers, ax=ax[0])
sns.heatmap(cm_rand, cmap='rocket_r', annot=True, fmt='.2f', xticklabels=sorted_speakers, yticklabels=sorted_speakers, ax=ax[1])

ax[0].set_ylabel('Actual')
ax[0].set_xlabel('Predicted')
ax[1].set_xlabel('Predicted')
ax[1].set_ylabel('Actual')
ax[0].set_title('Model')
ax[1].set_title('Random')
plt.show()
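`normalize='true'` divides each row of the confusion matrix by that row's total, so row i shows, for lines actually spoken by speaker i, the fraction assigned to each predicted speaker. A small self-contained sketch of that computation with hypothetical labels for three classes:

```python
def normalized_confusion(actual, predicted, n_classes):
    # counts[i][j] = lines actually from class i that were predicted as class j
    counts = [[0] * n_classes for _ in range(n_classes)]
    for a, p in zip(actual, predicted):
        counts[a][p] += 1
    # divide each row by its total (what normalize='true' does in sklearn)
    normed = []
    for row in counts:
        total = sum(row)
        normed.append([c / total if total else 0.0 for c in row])
    return normed

# hypothetical validation labels and predictions for 3 speakers
actual    = [0, 0, 1, 1, 2, 2, 2, 2]
predicted = [0, 1, 1, 1, 2, 2, 0, 2]
cm = normalized_confusion(actual, predicted, 3)
```

With this normalization the diagonal is per-class recall, which is why the model's heatmap is directly comparable to the random baseline even though the five characters have different line counts.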
In [34]:
# visualize the first prediction's explanation
shap.plots.text(shap_values)
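SHAP attributions are additive: for each label, the base value plus the sum of the per-token contributions equals the model's output for that label. Checking this with the (rounded) LABEL_1 values from the first explained sample, 'okay.':

```python
# LABEL_1 values for the first explained sample: base value ~0.323,
# token attributions 'okay' -0.313 and '.' -0.01
base_value = 0.323
attributions = {"okay": -0.313, ".": -0.01}

# additivity: base value + sum of attributions = model output for this label
model_output = base_value + sum(attributions.values())
# here the token contributions cancel the base value, so the output is ~0
```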


[SHAP text plot, sample [0] ('okay.'): per-label token attributions over LABEL_0…LABEL_4. 'okay' contributes -0.313 to LABEL_1 (base value 0.323) and +0.162 to LABEL_2 and LABEL_3; '.' contributes -0.01 to LABEL_1, +0.162 to LABEL_2 and -0.162 to LABEL_3. Attributions for LABEL_0 and LABEL_4 are zero.]


[SHAP text plot, sample [1] ('so why don't i take you on a tour of past dundie winners. we got fat jim halpert here. jim, why don't you show of your dundies to the camera?'): the model's output for LABEL_3 is 0.994, pushed up most by the tokens 'fat' (+0.175), 'die' (+0.165), 'dun' (+0.123), 'winners' (+0.097) and 'camera' (+0.073), and down by 'hal' (-0.103); attributions for the other four labels are comparatively small.]
In [23]:
explained_texts
Out[23]:
['So my looks have nothing to do with it?',
 'My personal favorite is the one he made for his condo association.',
 'Do not bring Shakespeare into this. How dare you play the bard card?',
 '[opens eyes wide in total surprise]',
 "Ah, well it's still very good. I bet I know someone who hasn't heard that joke... your daughter Emily. How's she doing?"]

Texts referenced in the observations below:

'So my looks have nothing to do with it?'
'My personal favorite is the one he made for his condo association.'
'Do not bring Shakespeare into this. How dare you play the bard card?'
'Ah, well it's still very good. I bet I know someone who hasn't heard that joke... your daughter Emily. How's she doing?'

  • 'favourite' and 'condo' are words more characteristic of Jim
  • 'emily' and 'daughter' are words more characteristic of Pam